import os

import numpy as np
import pylab as pl

# # # # # # # # # # # # # # # # # # # # # # # # # #
# This script reads the log files of the mini_WTA
# experiments and saves the data in a .npz archive.
# # # # # # # # # # # # # # # # # # # # # # # # # #

# # # # # # # # 
# P A R A M S #
# # # # # # # # 


path = "data/ANN runs/"     # directory of the .log input and the .npz output
fname = "TR47.log"          # default log file; used when nothing is entered

# User input filename.  Surrounding whitespace is dropped so that it cannot
# leak into the file name.
inname = raw_input("Enter filename: ").strip()

# Set input filename: drop whatever extension the user typed (if any) and use
# ".log" instead.  os.path.splitext is used rather than cutting a fixed three
# characters, so a name entered without extension (e.g. "TR47") stays intact.
if inname:
    fname = os.path.splitext(inname)[0] + ".log"

K = 2       # num neurons
N = 4       # num inputs

T_per_run = 499     # number of t in a run; multiple runs can be in the same file...
                    # ... If last run in logfile is X, this number should be X+1.

# Lines equal to one of these strings are treated as empty when filtering.
eol_delimiters = ("\n", "\n\r", "\r\n", " \n")

# Output archive: same base name as the log file, with a .npz extension.
out_fname = path + os.path.splitext(fname)[0] + ".npz"

# # # # # # # # # # # # # # #
# E N D   O F   P A R A M S #
# # # # # # # # # # # # # # #

# Read raw data
print(" > Read data from file '%s'." % (path + fname))
with open(path + fname) as f:       # the file is closed even if reading fails
    data = f.readlines()

# Remove comments and empty lines in a single pass.  A line is dropped when it
#   - equals one of the configured eol_delimiters,
#   - consists of whitespace only (would otherwise break the parser), or
#   - starts with one of the comment markers ">>>" / "  =====".
comment_prefixes = (">>>", "  =====")
data = [l for l in data
        if l not in eol_delimiters
        and l.strip()
        and not l.startswith(comment_prefixes)]

# Count T: every trial contributes exactly one line starting with "t:".
# The builtin sum keeps T a plain integer, also for an empty log file.
T = sum(1 for l in data if l.startswith("t:"))
print(" > Number of trials: T = %d." % T)

# Allocate zero-filled storage; the first axis of every array is the trial t.
U = np.zeros(shape=(T, K), dtype=float)       # membrane potential of each neuron
B = np.zeros(shape=(T, K), dtype=float)       # biases; row 0 stays zero as values are stored at t+1
V = np.zeros(shape=(T, K, N), dtype=float)    # synapse values, per neuron and input
Y = np.zeros(shape=(T, N), dtype=int)         # pattern presented at time t, one entry per input
Z = np.zeros(shape=T, dtype=int)              # winner recorded for each trial

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 
# Parse the data
# Note that in the log file:
#   (a) the membrane potentials are printed prior to the time stamp t
#   (b) The biases at time t are the values AFTER the update, so we
#       write them to B[t+1] to match the weights.
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 

def _read_scaled_values(text, count):
    """Return the first `count` tab-separated numbers in `text`, each / 1000.

    Raises ValueError if `text` holds fewer than `count` fields or if one of
    the requested fields is not a valid number.
    """
    fields = text.split("\t")
    if len(fields) < count:
        raise ValueError("expected %d tab-separated values, found %d in %r"
                         % (count, len(fields), text))
    return [float(field) / 1000. for field in fields[:count]]


# Fixed column positions inside the log lines (loop invariant).
potential_offset = len("k=0; Vk = ")    # prefix in front of a membrane potential
winner_offset = N + 3                   # jump over 0101\t|\t
bias_offset = N + 3 + 1 + 3             # jump over 0101\t|\t1\t|\t
weight_offset = len("Neuron 0:\t")      # prefix in front of the synapse values

lc = 0       # line counter
print(" > Read dataset:")
for t in xrange(T):
    print("  > Time: %d" % t)
    # Membrane potentials: one line per neuron, printed before the time stamp
    for k in xrange(K):
        U[t, k] = float(data[lc].strip()[potential_offset:])
        lc += 1
    # Time stamp: the line is skipped without being checked.
    # NOTE: a consistency check of the form
    #   (t % T_per_run) == int(line[len("t: "):])
    # existed here but was disabled by the original author.
    lc += 1
    # Input, winner, biases after learning
    l = data[lc].strip()
    for i in xrange(N):                 # input: one digit per input line
        Y[t, i] = int(l[i])
    Z[t] = int(l[winner_offset])        # winner: a single digit
    if t < T - 1:                       # biases belong to t+1; none for the last trial
        B[t + 1, :] = _read_scaled_values(l[bias_offset:], K)
    lc += 1
    # Synapse array: one line per neuron holding N tab-separated values
    for k in xrange(K):
        l = data[lc].strip()
        V[t, k, :] = _read_scaled_values(l[weight_offset:], N)
        lc += 1

print(" > Data extracted.")

# Bundle the sizes and the extracted arrays under their own names and write
# them to a single compressed .npz archive.
kwargs = {
    "K": K,
    "N": N,
    "T": T,
    "Y": Y,
    "Z": Z,
    "U": U,
    "B": B,
    "V": V,
}
print(" > Save extracted data to '%s'." % out_fname)
np.savez_compressed(out_fname, **kwargs)
print(" > Done.")